import queue
import torch
from typing import List

class WeightGradStore:

    cache = []
    weight_grad_queue = queue.Queue()
    combine_bw = False

    @classmethod
    def set_combine_bw(cls, combine_bw):
        # For the following backward pass, combine W with B and skip next W
        cls.combine_bw = combine_bw

    @classmethod
    def split_bw(cls):
        # For the following backward pass, combine W with B and skip next W
        return not cls.combine_bw

    @classmethod
    def put(cls, inputs, func, pos_func = None):
        if cls.combine_bw:
            func(*inputs)
            if pos_func is not None:
                if isinstance(inputs[0], List):
                    pos_func(*inputs[0])
                else:
                    pos_func(inputs[0])
            return
        # Store the weight gradient computation of linear layers.
        cls.cache.append((inputs, func, pos_func))

    @classmethod
    def flush(cls):
        if cls.combine_bw:
            cls.combine_bw = False
            return
        # Collect all stored computations during backward as a W.
        cls.weight_grad_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        assert not cls.combine_bw
        # Execute a single W.
        assert cls.weight_grad_queue.qsize() > 0
        stored_grads = cls.weight_grad_queue.get()
        with torch.enable_grad():
            for inputs, func, pos_func in stored_grads:
                func(*inputs)
                if pos_func is not None:
                    if isinstance(inputs[0], List):
                        pos_func(*inputs[0])
                    else:
                        pos_func(inputs[0])